Add ExportFriendlyMultiheadAttention for dynamic shape torch.export#2
Closed
Conversation
…tensor

Create the box scale tensor directly on the target device instead of using pin_memory().to(device, non_blocking=True). This enables:

- CPU-only inference (pin_memory requires CUDA)
- Apple MPS inference (pin_memory not supported)
- PT2 export without runtime patching

The scale tensor is always exactly 4 floats (16-32 bytes). For such a small tensor, the pin_memory overhead likely exceeds any async transfer benefit. Creating the tensor directly on device avoids the CPU→GPU transfer entirely.
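A minimal sketch of the approach this commit describes. The helper name `make_box_scale` is illustrative, not the repository's actual function; only the device-direct construction pattern is from the commit message.

```python
import torch

def make_box_scale(W, H, boxes_xyxy):
    # Hypothetical helper illustrating the commit's pattern: build the
    # 4-element scale tensor directly on the boxes' device, avoiding
    # pin_memory() (which requires CUDA) and the CPU→GPU async copy.
    scale = torch.tensor(
        [W, H, W, H],
        dtype=boxes_xyxy.dtype,
        device=boxes_xyxy.device,
    )
    return scale.view(1, 1, 4)
```

Because the tensor is created on the target device from the start, this path works identically on CPU, MPS, and CUDA, and traces cleanly under PT2 export.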
This adds a custom MultiheadAttention implementation that bypasses F.multi_head_attention_forward to enable torch.export with dynamic shapes (e.g., variable image H/W).

The problem: nn.MultiheadAttention uses F.multi_head_attention_forward, which has internal guards on sequence length (e.g., Eq(seq_len, 5184)) that fail during torch.export because the sequence length is symbolic.

The solution: ExportFriendlyMultiheadAttention:

- Manually projects Q, K, V using the same combined in_proj_weight
- Calls F.scaled_dot_product_attention directly
- Avoids all shape validation guards in F.multi_head_attention_forward

Also adds a replace_mha_with_export_friendly() utility function to recursively replace all nn.MultiheadAttention modules in a model.

Related PyTorch issues:

- pytorch/pytorch#170127
- pytorch/pytorch#124502
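The three bullets above can be sketched as a self-attention module. This is a simplified illustration (class name, init, and batch-first-only layout are assumptions, not the PR's exact code): one combined projection weight is split into Q/K/V pieces, and attention goes straight to F.scaled_dot_product_attention, so F.multi_head_attention_forward's sequence-length guards are never reached.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleExportMHA(nn.Module):
    """Sketch of the export-friendly idea: manual Q/K/V projection from a
    combined in_proj_weight, then SDPA directly (batch_first layout)."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Same combined-weight layout as nn.MultiheadAttention.
        self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
        self.in_proj_bias = nn.Parameter(torch.zeros(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        nn.init.xavier_uniform_(self.in_proj_weight)

    def forward(self, query, key, value):
        B, S, E = query.shape
        w_q, w_k, w_v = self.in_proj_weight.chunk(3)
        b_q, b_k, b_v = self.in_proj_bias.chunk(3)
        q = F.linear(query, w_q, b_q)
        k = F.linear(key, w_k, b_k)
        v = F.linear(value, w_v, b_v)

        # (B, seq, E) -> (B, num_heads, seq, head_dim); -1 keeps the
        # sequence length symbolic-shape friendly.
        def split(x):
            return x.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)

        out = F.scaled_dot_product_attention(split(q), split(k), split(v))
        out = out.transpose(1, 2).reshape(B, S, E)
        return self.out_proj(out)
```

No explicit shape validation is performed, which is exactly why symbolic sequence lengths (H*W under dynamic shapes) pass through without guard failures.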
rbavery
commented
Jan 22, 2026
```diff
- scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype)
- scale = scale.pin_memory().to(device=boxes_xyxy.device, non_blocking=True)
- scale = scale.view(1, 1, 4)
+ scale = torch.tensor(
```
Member
Author
Removes memory pinning for CPU export. See #1.
rbavery
commented
Jan 22, 2026
```python
        return super().forward(*args, **kwargs)


class ExportFriendlyMultiheadAttention(nn.Module):
```
Member
Author
While export now works, we still need to validate that running the exported model produces the same outputs as the original eager-mode model.
During torch.export with dynamic H/W dimensions, SymInt values cannot be used as dict keys. These caches prevented dynamic shape export.

Changes:

- position_encoding.py: Remove the (H, W)-keyed cache in forward()
- decoder.py: Remove the coord_cache dict lookup in _get_rpb_matrix()

The computation is cheap (just torch.arange), so always computing is acceptable for export use cases.
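A sketch of the cache-free pattern this commit adopts. The function name `coord_grid` is hypothetical; only the replace-a-dict-cache-with-recomputation idea is from the commit.

```python
import torch

def coord_grid(H, W):
    # Hypothetical stand-in for the removed caches: recompute the
    # coordinate grid with torch.arange on every call instead of storing
    # it in a dict keyed by (H, W). Under torch.export, H and W are
    # SymInts and cannot serve as dict keys, so the cache-free path is
    # the only one that traces; the arange work is cheap to repeat.
    ys = torch.arange(H, dtype=torch.float32)
    xs = torch.arange(W, dtype=torch.float32)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack((gy, gx), dim=-1)  # shape (H, W, 2)
```

For eager-mode use, the per-call cost is one arange per axis plus a stack, which is negligible next to the attention compute it feeds.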
Member
Author
Duplicate of #3.
Summary
This PR adds ExportFriendlyMultiheadAttention, a custom MultiheadAttention implementation that enables torch.export with dynamic shapes (e.g., variable image H/W).

Problem
When exporting SAM3 with dynamic image dimensions using torch.export, the export fails with a sequence-length guard error.

This happens because nn.MultiheadAttention calls F.multi_head_attention_forward, which has internal guards on the sequence length. When the sequence length is symbolic (e.g., H*W where H and W are dynamic), these guards cannot be statically evaluated, causing export to fail.

Note: The commonly suggested workaround sdpa_kernel([SDPBackend.MATH]) does NOT work for this case, because the guard failure happens before scaled_dot_product_attention is called - it occurs in F.multi_head_attention_forward's shape validation code.

Solution
ExportFriendlyMultiheadAttention bypasses F.multi_head_attention_forward entirely by:

- projecting Q, K, V manually with the combined in_proj_weight
- calling F.scaled_dot_product_attention directly

Also includes:
- a from_nn_mha() classmethod to create an instance from an existing nn.MultiheadAttention with weight copying
- a replace_mha_with_export_friendly() utility function to recursively replace all MHA modules in a model

Test Results
Usage
Related Issues

- pytorch/pytorch#170127
- pytorch/pytorch#124502